{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# C-host 代码生成简介"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import set_env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "import tvm.testing\n",
    "\n",
    "from tvm import te\n",
    "from tvm.contrib import utils\n",
    "from tvm.script import tir as T, ir as I\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `add` c-host 代码生成"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "nn = 1024\n",
    "n = tvm.runtime.convert(nn)\n",
    "A = te.placeholder((n,), name=\"A\")\n",
    "B = te.placeholder((n,), name=\"B\")\n",
    "C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name=\"C\")\n",
    "s = te.create_schedule(C.op)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "tags": [
     "hidden-output"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "// tvm target: c -keys=cpu \n",
      "#define TVM_EXPORTS\n",
      "#include \"tvm/runtime/c_runtime_api.h\"\n",
      "#include \"tvm/runtime/c_backend_api.h\"\n",
      "#include <math.h>\n",
      "#include <stdbool.h>\n",
      "#ifdef __cplusplus\n",
      "extern \"C\"\n",
      "#endif\n",
      "TVM_DLL int32_t test_fadd(void* args, int32_t* arg_type_ids, int32_t num_args, void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle);\n",
      "#ifdef __cplusplus\n",
      "extern \"C\"\n",
      "#endif\n",
      "TVM_DLL int32_t test_fadd(void* args, int32_t* arg_type_ids, int32_t num_args, void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle) {\n",
      "  int32_t A_code = arg_type_ids[0];\n",
      "  int32_t B_code = arg_type_ids[1];\n",
      "  int32_t C_code = arg_type_ids[2];\n",
      "  void* A = (((TVMValue*)args)[0].v_handle);\n",
      "  void* B = (((TVMValue*)args)[1].v_handle);\n",
      "  void* C = (((TVMValue*)args)[2].v_handle);\n",
      "  void* test_fadd_A_shape = (((DLTensor*)A)[0].shape);\n",
      "  void* test_fadd_A_strides = (((DLTensor*)A)[0].strides);\n",
      "  int32_t dev_id = (((DLTensor*)A)[0].device.device_id);\n",
      "  void* A_1 = (((DLTensor*)A)[0].data);\n",
      "  void* test_fadd_B_shape = (((DLTensor*)B)[0].shape);\n",
      "  void* test_fadd_B_strides = (((DLTensor*)B)[0].strides);\n",
      "  void* B_1 = (((DLTensor*)B)[0].data);\n",
      "  void* test_fadd_C_shape = (((DLTensor*)C)[0].shape);\n",
      "  void* test_fadd_C_strides = (((DLTensor*)C)[0].strides);\n",
      "  void* C_1 = (((DLTensor*)C)[0].data);\n",
      "  if (!(test_fadd_A_strides == NULL)) {\n",
      "  }\n",
      "  if (!(test_fadd_B_strides == NULL)) {\n",
      "  }\n",
      "  if (!(test_fadd_C_strides == NULL)) {\n",
      "  }\n",
      "  for (int32_t i0 = 0; i0 < 1024; ++i0) {\n",
      "    ((float*)C_1)[i0] = (((float*)A_1)[i0] + ((float*)B_1)[i0]);\n",
      "  }\n",
      "  return 0;\n",
      "}\n",
      "\n",
      "// CodegenC: NOTE: Auto-generated entry function\n",
      "#ifdef __cplusplus\n",
      "extern \"C\"\n",
      "#endif\n",
      "TVM_DLL int32_t __tvm_main__(void* args, int* arg_type_ids, int num_args, void* out_ret_value, int* out_ret_tcode, void* resource_handle) {\n",
      "  return test_fadd(args, arg_type_ids, num_args, out_ret_value, out_ret_tcode, resource_handle);\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "mhost = tvm.build(s, [A, B, C], \"c\", name=\"test_fadd\")\n",
    "print(mhost.get_source())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "temp = utils.tempdir()\n",
    "path_dso = temp.relpath(\"temp.so\")\n",
    "mhost.export_library(path_dso)\n",
    "m = tvm.runtime.load_module(path_dso)\n",
    "fadd = m[\"test_fadd\"]\n",
    "dev = tvm.cpu(0)\n",
    "# launch the kernel.\n",
    "n = nn\n",
    "a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)\n",
    "b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev)\n",
    "c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev)\n",
    "fadd(a, b, c)\n",
    "tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `add_pipeline` c host 代码生成"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "nn = 1024\n",
    "n = tvm.runtime.convert(nn)\n",
    "A = te.placeholder((n,), name=\"A\")\n",
    "B = te.placeholder((n,), name=\"B\")\n",
    "AA = te.compute((n,), lambda *i: A(*i), name=\"A\")\n",
    "BB = te.compute((n,), lambda *i: B(*i), name=\"B\")\n",
    "T = te.compute(A.shape, lambda *i: AA(*i) + BB(*i), name=\"T\")\n",
    "C = te.compute(A.shape, lambda *i: T(*i), name=\"C\")\n",
    "s = te.create_schedule(C.op)\n",
    "xo, xi = s[C].split(C.op.axis[0], factor=4)\n",
    "xo1, xo2 = s[C].split(xo, factor=13)\n",
    "s[C].parallel(xo2)\n",
    "s[C].pragma(xo1, \"parallel_launch_point\")\n",
    "s[C].pragma(xo2, \"parallel_stride_pattern\")\n",
    "s[C].pragma(xo2, \"parallel_barrier_when_finish\")\n",
    "# FIXME(tvm-team): vector operators are not supported for codegen to C yet\n",
    "# s[C].vectorize(xi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_c():\n",
    "    # Specifically allow offset to test codepath when offset is available\n",
    "    Ab = tvm.tir.decl_buffer(\n",
    "        A.shape, A.dtype, elem_offset=te.size_var(\"Aoffset\"), offset_factor=8, name=\"A\"\n",
    "    )\n",
    "    binds = {A: Ab}\n",
    "    # BUILD and invoke the kernel.\n",
    "    f1 = tvm.lower(s, [A, B, C], name=\"test_fadd_pipeline\")\n",
    "    mhost = tvm.build(f1, target=\"c\")\n",
    "\n",
    "    temp = utils.tempdir()\n",
    "    path_dso = temp.relpath(\"temp.so\")\n",
    "    mhost.export_library(path_dso)\n",
    "    m = tvm.runtime.load_module(path_dso)\n",
    "    fadd = m[\"test_fadd_pipeline\"]\n",
    "    dev = tvm.cpu(0)\n",
    "    # launch the kernel.\n",
    "    n = nn\n",
    "    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)\n",
    "    b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev)\n",
    "    c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev)\n",
    "    fadd(a, b, c)\n",
    "    tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())\n",
    "\n",
    "check_c()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `tir.reinterpret` c host"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "nn = 1024\n",
    "n = tvm.runtime.convert(nn)\n",
    "A = te.placeholder((n,), name=\"A\", dtype=\"int32\")\n",
    "B = te.compute(\n",
    "    A.shape, lambda *i: tvm.tir.call_intrin(\"float32\", \"tir.reinterpret\", 2 + A(*i)), name=\"B\"\n",
    ")\n",
    "s = te.create_schedule(B.op)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_c():\n",
    "    mhost = tvm.build(s, [A, B], \"c\", name=\"test_reinterpret\")\n",
    "    temp = utils.tempdir()\n",
    "    path_dso = temp.relpath(\"temp.so\")\n",
    "    mhost.export_library(path_dso)\n",
    "    m = tvm.runtime.load_module(path_dso)\n",
    "    fadd = m[\"test_reinterpret\"]\n",
    "    dev = tvm.cpu(0)\n",
    "    n = nn\n",
    "    a = tvm.nd.array(np.random.randint(-(2**30), 2**30, size=n).astype(A.dtype), dev)\n",
    "    b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev)\n",
    "    fadd(a, b)\n",
    "    tvm.testing.assert_allclose(b.numpy(), (2 + a.numpy()).view(\"float32\"))\n",
    "\n",
    "check_c()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `tir.ceil` c host"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_ceil():\n",
    "    nn = 1024\n",
    "    n = tvm.runtime.convert(nn)\n",
    "    A = te.placeholder((n,), name=\"A\", dtype=\"float32\")\n",
    "    B = te.compute(A.shape, lambda *i: tvm.tir.call_intrin(\"float32\", \"tir.ceil\", A(*i)), name=\"B\")\n",
    "    s = te.create_schedule(B.op)\n",
    "\n",
    "    def check_c():\n",
    "        mhost = tvm.build(s, [A, B], \"c\", name=\"test_ceil\")\n",
    "        temp = utils.tempdir()\n",
    "        path_dso = temp.relpath(\"temp.so\")\n",
    "        mhost.export_library(path_dso)\n",
    "        m = tvm.runtime.load_module(path_dso)\n",
    "        fceil = m[\"test_ceil\"]\n",
    "        dev = tvm.cpu(0)\n",
    "        n = nn\n",
    "        a = tvm.nd.array(np.random.rand(n).astype(A.dtype), dev)\n",
    "        b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev)\n",
    "        fceil(a, b)\n",
    "        tvm.testing.assert_allclose(b.numpy(), (np.ceil(a.numpy()).view(\"float32\")))\n",
    "\n",
    "    check_c()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `tir.floor` c host"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_floor():\n",
    "    nn = 1024\n",
    "    n = tvm.runtime.convert(nn)\n",
    "    A = te.placeholder((n,), name=\"A\", dtype=\"float32\")\n",
    "    B = te.compute(A.shape, lambda *i: tvm.tir.call_intrin(\"float32\", \"tir.floor\", A(*i)), name=\"B\")\n",
    "    s = te.create_schedule(B.op)\n",
    "\n",
    "    def check_c():\n",
    "        mhost = tvm.build(s, [A, B], \"c\", name=\"test_floor\")\n",
    "        temp = utils.tempdir()\n",
    "        path_dso = temp.relpath(\"temp.so\")\n",
    "        mhost.export_library(path_dso)\n",
    "        m = tvm.runtime.load_module(path_dso)\n",
    "        ffloor = m[\"test_floor\"]\n",
    "        dev = tvm.cpu(0)\n",
    "        n = nn\n",
    "        a = tvm.nd.array(np.random.rand(n).astype(A.dtype), dev)\n",
    "        b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev)\n",
    "        ffloor(a, b)\n",
    "        tvm.testing.assert_allclose(b.numpy(), (np.floor(a.numpy()).view(\"float32\")))\n",
    "\n",
    "    check_c()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `tir.round` c host"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_round():\n",
    "    nn = 1024\n",
    "    n = tvm.runtime.convert(nn)\n",
    "    A = te.placeholder((n,), name=\"A\", dtype=\"float32\")\n",
    "    B = te.compute(A.shape, lambda *i: tvm.tir.call_intrin(\"float32\", \"tir.round\", A(*i)), name=\"B\")\n",
    "    s = te.create_schedule(B.op)\n",
    "\n",
    "    def check_c():\n",
    "        mhost = tvm.build(s, [A, B], \"c\", name=\"test_round\")\n",
    "        temp = utils.tempdir()\n",
    "        path_dso = temp.relpath(\"temp.so\")\n",
    "        mhost.export_library(path_dso)\n",
    "        m = tvm.runtime.load_module(path_dso)\n",
    "        fround = m[\"test_round\"]\n",
    "        dev = tvm.cpu(0)\n",
    "        n = nn\n",
    "        a = tvm.nd.array(np.random.rand(n).astype(A.dtype), dev)\n",
    "        b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev)\n",
    "        fround(a, b)\n",
    "        tvm.testing.assert_allclose(b.numpy(), (np.round(a.numpy()).view(\"float32\")))\n",
    "\n",
    "    check_c()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## subroutine c host"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_subroutine_call():\n",
    "    @I.ir_module\n",
    "    class mod:\n",
    "        @T.prim_func\n",
    "        def main(A: T.Buffer(1, dtype=\"float32\")):\n",
    "            mod.subroutine(A.data)\n",
    "\n",
    "        @T.prim_func(private=True)\n",
    "        def subroutine(A_data: T.handle(\"float32\")):\n",
    "            A = T.decl_buffer(1, dtype=\"float32\", data=A_data)\n",
    "            A[0] = 42.0\n",
    "\n",
    "    built = tvm.build(mod, target=\"c\")\n",
    "\n",
    "    func_names = list(built[\"get_func_names\"]())\n",
    "    assert (\n",
    "        \"main\" in func_names\n",
    "    ), \"Externally exposed functions should be listed in available functions.\"\n",
    "    assert (\n",
    "        \"subroutine\" not in func_names\n",
    "    ), \"Internal function should not be listed in available functions.\"\n",
    "\n",
    "    source = built.get_source()\n",
    "    assert (\n",
    "        source.count(\"main(void*\") == 2\n",
    "    ), \"Expected two occurrences, for forward-declaration and definition\"\n",
    "    assert (\n",
    "        source.count(\"subroutine(float*\") == 2\n",
    "    ), \"Expected two occurrences, for forward-declaration and definition\"\n",
    "    assert (\n",
    "        source.count(\"subroutine(\") == 3\n",
    "    ), \"Expected three occurrences, for forward-declaration, definition, and call from main.\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## {func}`tvm.tir.call_intrin` c host"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_call_packed():\n",
    "    def fake_func(fname=\"fake.func\"):\n",
    "        ib = tvm.tir.ir_builder.create()\n",
    "        A = ib.pointer(\"float32\", name=\"A\")\n",
    "        fake_func1 = tvm.tir.call_packed(fname, A[0])\n",
    "\n",
    "        ib.emit(fake_func1)\n",
    "        body = ib.get()\n",
    "        return A, body\n",
    "\n",
    "    def check_global_packed_func():\n",
    "        fname = \"fake.func\"\n",
    "        A, body = fake_func(fname)\n",
    "        func1 = tvm.tir.PrimFunc([A], body).with_attr(\"global_symbol\", \"func1\")\n",
    "        B, body = fake_func()\n",
    "        func2 = tvm.tir.PrimFunc([B], body).with_attr(\"global_symbol\", \"func2\")\n",
    "        mod = tvm.IRModule({\"fake_func1\": func1, \"fake_func2\": func2})\n",
    "        fcode = tvm.build(mod, None, \"c\")\n",
    "        src = fcode.get_source()\n",
    "\n",
    "        # there are two locations calling the packed func\n",
    "        assert src.count(fname) == 2\n",
    "\n",
    "        suffix = \"_packed\"\n",
    "        packed_func_name = fname + suffix\n",
    "        # func name will be standardized by GetUniqueName and not exists anymore\n",
    "        assert src.find(packed_func_name) == -1\n",
    "\n",
    "        packed_func_real_name = \"_\".join(fname.split(\".\")) + suffix\n",
    "        func_declaration = \"static void* %s = NULL;\" % packed_func_real_name\n",
    "        # src only has 1 valid declaration\n",
    "        assert src.count(func_declaration) == 1\n",
    "\n",
    "    check_global_packed_func()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py312x",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
